# -*- coding: utf-8 -*-
"""
Created on Mon Oct  7 15:58:39 2020

@author: Erfan
"""

from __future__ import print_function, division # Grab some 
import numpy as np
from casadi import *
from Parameters import *
import Soil_Functions as sf

# Divisor of the evapotranspiration sink (Kc*Et0) used in ode_CP/ode_CP2
# (presumably the root-zone depth -- TODO confirm value and units).
Hz = 0.7
Nr,Nt,Nz,dr,dt,dz,Np,NN,Nx,Nw,Ny,Nv=circular_parameters()
# Grid indices [ir, jt, kz] of every node, in state-vector order: the radial
# index varies fastest, then the angular index, then the vertical layer.
coordinates = [[ir, jt, kz]
               for kz in range(Nz)
               for jt in range(Nt)
               for ir in range(Nr)]
   
def ode_CP(x,u,w,m,Kc,Et0):
    """Build dh/dt for every node of the cylindrical (r, theta, z) soil grid.

    Semi-discretised soil-water flow on the grid described by
    ``circular_parameters()``. Node ``i`` has grid indices
    ``coordinates[i] == [ir, jt, kz]``. Face conductivities are arithmetic
    means of ``sf.KFun`` (presumably K(h)) on both sides of the face. The net
    flux balance of each node is divided by ``sf.CFun`` (presumably the
    capacity C(h)) -- TODO confirm against Soil_Functions.

    Boundary handling:
      * axis nodes (ir == 0): the regular 1/r stencil is singular, so the
        radial and the angular term each use 2*K*(h_next - h_axis)/dr**2,
        where h_next is the ir == 1 node at the same angle and layer;
      * outer wall (ir == Nr-1): the node mirrors itself -> no radial flux;
      * angular direction: periodic (first and last angular index are
        neighbours; assumes Np == Nr*Nt -- confirm);
      * bottom layer (kz == 0): the node mirrors itself, so DHzL is 0 and the
        bottom face contributes only through the -KzL/dz part of Term3;
      * top layer (kz == Nz-1): the upper face carries the irrigation flux.

    Args:
        x: Indexable state of length ``NN`` (one value per node).
        u: Irrigation value applied to the top-layer node selected by the
            irrigation schedule (see ``m``); 0 elsewhere.
        w: Term added to the returned vector.
        m: Schedule counter; ``mod(m, 40)`` is the phase of a 40-step cycle
            (presumably the time-step index -- confirm with the caller).
        Kc, Et0: Their product divided by ``Hz`` is subtracted from every
            node as the evapotranspiration sink.

    Returns:
        ``SX`` vector of length ``NN``: the per-node derivatives plus ``w``.
    """
    Nr,Nt,Nz,dr,dt,dz,Np,NN,Nx,Nw,Ny,Nv=circular_parameters()
    state = x
    layer_size = Nr*Nt  # parameter_node() is indexed within one layer

    def params(node):
        # Soil parameters of ``node``; every layer reuses the same map.
        return parameter_node(node % layer_size)

    def face_K(K_here, h_nb, p_nb):
        # Arithmetic mean of the conductivities on either side of a face.
        return 0.5*(K_here + sf.KFun(h_nb, p_nb))

    # Evapotranspiration sink (0 if the evapotranspiration term is not
    # considered). It is the same for every node, so it is built once.
    Term4 = (Kc*Et0)/Hz
    phase = mod(m, 40)  # position in the 40-step irrigation cycle

    dhdt = SX.zeros(NN)
    for i in range(NN):
        ir, jt, kz = coordinates[i]
        p_cur = params(i)
        h_cur = state[i]
        K_cur = sf.KFun(h_cur, p_cur)  # shared by every face of this node

        if ir == 0:
            # Axis node (r = 0): singular-limit form, used for both the
            # radial (Term0) and the angular (Term1) contribution.
            h_next = state[i+1]
            K_sing = 0.5*(sf.KFun(h_next, params(i+1)) + K_cur)
            Term0 = (2*K_sing)*((h_next - h_cur)/dr**2)
            Term1 = Term0
        else:
            r_i = ir*dr

            # Radial neighbours; at the outer wall the node mirrors itself.
            h_rL, p_rL = state[i-1], params(i-1)
            if ir == Nr - 1:
                h_rR, p_rR = h_cur, p_cur
            else:
                h_rR, p_rR = state[i+1], params(i+1)
            KrL = face_K(K_cur, h_rL, p_rL)
            KrR = face_K(K_cur, h_rR, p_rR)
            rrL = 0.5*(r_i + (ir-1)*dr)  # radius of the inner face
            rrR = 0.5*(r_i + (ir+1)*dr)  # radius of the outer face
            DHrL = (h_cur - h_rL)/dr
            DHrR = (h_rR - h_cur)/dr
            Term0 = 1/(r_i*dr) * (rrR*KrR*DHrR - rrL*KrL*DHrL)

            # Angular neighbours, wrapping around the ring at the first and
            # last angular index. Here ``dt`` is the angular grid step (it
            # divides angular head differences), not a time step.
            if jt == 0:
                nb_L, nb_R = i + (Np - Nr), i + Nr
            elif jt == Nt - 1:
                nb_L, nb_R = i - Nr, i - (Np - Nr)
            else:
                nb_L, nb_R = i - Nr, i + Nr
            h_tL, h_tR = state[nb_L], state[nb_R]
            KtL = face_K(K_cur, h_tL, params(nb_L))
            KtR = face_K(K_cur, h_tR, params(nb_R))
            DHtL = (h_cur - h_tL)/dt
            DHtR = (h_tR - h_cur)/dt
            Term1 = 1/(r_i*dt) * ((KtR/r_i)*DHtR - (KtL/r_i)*DHtL)

        # Lower vertical face; the bottom layer mirrors itself.
        if kz == 0:
            h_zL, p_zL = h_cur, p_cur
        else:
            h_zL, p_zL = state[i-Np], params(i-Np)
        KzL = face_K(K_cur, h_zL, p_zL)
        DHzL = (h_cur - h_zL)/dz

        if kz == Nz - 1:
            # Top layer: the upper face carries the irrigation term. The node
            # with angular index jt receives u when the cycle phase equals
            # 20 - jt (phases 0..20) or 60 - jt (phases 21..39).
            irr1 = if_else(logic_and(phase>=0, logic_and(phase<=20, phase==20-jt)), u, 0)
            irr = if_else(logic_and(phase>=21, logic_and(phase<=39, phase==60-jt)), u, irr1)
            Term2 = (1.0/dz)*(-irr - KzL*DHzL)
            Term3 = (1.0/dz)*(-1*KzL)
        else:
            # Upper neighbour exists only below the top layer, so it is
            # built here (never from a previous node's leftovers).
            h_zU, p_zU = state[i+Np], params(i+Np)
            KzU = face_K(K_cur, h_zU, p_zU)
            DHzU = (h_zU - h_cur)/dz
            Term2 = (1.0/dz) * (KzU*DHzU - KzL*DHzL)
            Term3 = (1.0/dz) * (KzU - KzL)

        Term5 = Term0 + Term1 + Term2 + Term3 - Term4
        dhdt[i] = Term5/sf.CFun(h_cur, p_cur)

    return dhdt + w

def ode_CP2(x_scaled,u,w,m,Kc,Et0):
    """Evaluate ``ode_CP`` for a min-max normalised state vector.

    ``x_scaled`` is mapped back with ``x_scaled * 33.1045 + (-33.1619)``.
    The result is then built by ``ode_CP``, using the same stencil, boundary
    handling, irrigation schedule and evapotranspiration sink.

    Args:
        x_scaled: Normalised state of length ``NN``.
        u, w, m, Kc, Et0: Passed through unchanged to ``ode_CP``.

    Returns:
        ``SX`` vector of length ``NN`` from ``ode_CP``.

    NOTE(review): the returned derivative (and the added ``w``) is in
    *unscaled* units, i.e. it is d(x_unscaled)/dt. It is not
    d(x_scaled)/dt, which would additionally be divided by ``delta_value``.
    Confirm that the caller expects this.
    """
    # Min-max normalisation constants used to undo the scaling.
    delta_value = 33.1045
    min_value = -33.1619
    x_unscaled = (x_scaled * delta_value) + min_value
    return ode_CP(x_unscaled, u, w, m, Kc, Et0)
